package strategy

import (
	"context"
	"crypto/ecdsa"
	"crypto/x509"
	"encoding/json"
	"encoding/pem"
	"fmt"
	"gitee.com/lipore/plume/auth"
	"gitee.com/lipore/plume/errors"
	"gitee.com/lipore/plume/logger"
	"gitee.com/lipore/plume/redis"
	"sync"
	"time"
)

// RedisPublicKeyInfo is the JSON message published on the store's Redis
// pub/sub channel when a public key is saved, and decoded by subscribers.
type RedisPublicKeyInfo struct {
	// KeyId identifies the key; it is also the suffix of the Redis key the
	// encoded key is stored under.
	KeyId      string `json:"kid"`
	// EncodedKey is the key in PKIX form wrapped in a PEM "PUBLIC KEY" block.
	EncodedKey string `json:"key"`
}

// RedisPublicKeyStoreOptions mutates a redisPublicKeyStore during
// construction; NewRedisPublicKeyStore applies them in the order given.
type RedisPublicKeyStoreOptions func(store *redisPublicKeyStore)

// WithChannel returns an option that sets the Redis pub/sub channel on which
// saved public keys are published and from which they are received.
func WithChannel(channel string) RedisPublicKeyStoreOptions {
	setChannel := func(target *redisPublicKeyStore) {
		target.channel = channel
	}
	return setChannel
}

// WithKeyPrefix returns an option that sets the prefix which, followed by
// "_", is prepended to every key ID to form the Redis key.
func WithKeyPrefix(keyPrefix string) RedisPublicKeyStoreOptions {
	return RedisPublicKeyStoreOptions(func(target *redisPublicKeyStore) {
		target.keyPrefix = keyPrefix
	})
}

// redisPublicKeyStore is the auth.PublicKeyStore returned by
// NewRedisPublicKeyStore. Public keys are written to Redis under
// "<keyPrefix>_<keyId>", announced on a pub/sub channel, and kept in an
// in-process cache.
type redisPublicKeyStore struct {
	// Mutex guards publicKeyCache and cacheMatchCount.
	sync.Mutex

	// client is used for SET/GET/PUBLISH and for the subscription.
	client redis.Client

	// keyPrefix is prepended (followed by "_") to key IDs to build Redis keys.
	keyPrefix string
	// channel is the pub/sub channel keys are published on and received from.
	channel   string

	// publicKeyCache maps a key ID to its decoded public key.
	publicKeyCache  map[string]ecdsa.PublicKey
	// cacheMatchCount counts, per key ID, LoadPublicKey calls that returned a
	// key since the last CleanUnusedKey; CleanUnusedKey replaces it with an
	// empty map.
	cacheMatchCount map[string]int64
}

func NewRedisPublicKeyStore(ctx context.Context, opts ...RedisPublicKeyStoreOptions) auth.PublicKeyStore {
	store := &redisPublicKeyStore{
		publicKeyCache:  make(map[string]ecdsa.PublicKey),
		cacheMatchCount: make(map[string]int64),
	}

	for _, opt := range opts {
		opt(store)
	}

	store.client = redis.Default()

	go func() {
		timer := time.NewTimer(24 * time.Hour)
		for {
			select {
			case <-ctx.Done():
				timer.Stop()
				return
			case <-timer.C:
				store.CleanUnusedKey()
				timer.Reset(24 * time.Hour)
			}
		}
	}()
	store.listen(ctx)
	return store
}

// listen subscribes to s.channel and caches every public key announced there.
//
// Each payload is expected to be a JSON-encoded RedisPublicKeyInfo whose
// EncodedKey is a PEM "PUBLIC KEY" block. Messages that fail JSON decoding or
// key decoding are logged and dropped. An error returned by redis.Subscribe
// is logged, not propagated.
func (s *redisPublicKeyStore) listen(ctx context.Context) {
	err := redis.Subscribe(ctx, s.client, func(message *redis.Message) {
		kiStr := message.Payload

		ki := &RedisPublicKeyInfo{}
		if err := json.Unmarshal([]byte(kiStr), ki); err != nil {
			logger.Warnf("parse redis message failed: %s", err.Error())
			logger.Debugf("origin message: %s", kiStr)
			// Drop the message: continuing would try to decode a zero or
			// partially filled payload and log a second, misleading error.
			return
		}

		key, err := s.unmarshalPublicKey(ki.EncodedKey)
		if err != nil {
			logger.Warnf("parse encoded key failed: %s", err.Error())
			logger.Debugf("encoded key: %s", ki.EncodedKey)
			return
		}
		s.Lock()
		s.publicKeyCache[ki.KeyId] = *key
		s.Unlock()
	}, redis.WithChannel(s.channel), redis.WithCheckALiveInterval(1*time.Second))
	if err != nil {
		logger.Warnf("%v", errors.WithMessage(err, "redis subscription listen failed"))
	}
}

// buildKey returns the Redis key for keyId: the configured prefix, an
// underscore, then the key ID.
func (s *redisPublicKeyStore) buildKey(keyId string) string {
	return s.keyPrefix + "_" + keyId
}

// marshalPublicKey encodes key as PKIX DER and wraps it in a PEM block of
// type "PUBLIC KEY". Any error from x509.MarshalPKIXPublicKey is returned
// unchanged.
func (s *redisPublicKeyStore) marshalPublicKey(key *ecdsa.PublicKey) (string, error) {
	der, err := x509.MarshalPKIXPublicKey(key)
	if err != nil {
		return "", err
	}
	block := &pem.Block{Type: "PUBLIC KEY", Bytes: der}
	return string(pem.EncodeToMemory(block)), nil
}

// unmarshalPublicKey decodes the first PEM block in keyStr and parses it as a
// PKIX public key, which must be an ECDSA key.
//
// It returns an error if keyStr contains no PEM block, if the block is not
// valid PKIX, or if it holds a non-ECDSA key (e.g. RSA or Ed25519). The input
// may come from pub/sub messages, so an unexpected key type must not panic.
func (s *redisPublicKeyStore) unmarshalPublicKey(keyStr string) (*ecdsa.PublicKey, error) {
	encodedBlock, _ := pem.Decode([]byte(keyStr))
	if encodedBlock == nil {
		return nil, errors.New("invalid public key message body droped.")
	}
	key, err := x509.ParsePKIXPublicKey(encodedBlock.Bytes)
	if err != nil {
		return nil, err
	}
	ecKey, ok := key.(*ecdsa.PublicKey)
	if !ok {
		return nil, fmt.Errorf("unsupported public key type %T, want *ecdsa.PublicKey", key)
	}
	return ecKey, nil
}

// SavePublicKey stores info.Key in Redis under buildKey(info.KeyId) with a TTL
// of info.ExpireAt minus now. It then publishes a RedisPublicKeyInfo on
// s.channel so that subscribed stores can cache the key.
//
// Marshalling and SET errors are logged and returned. A PUBLISH error is
// logged and ignored, because the key is already persisted and can still be
// loaded from Redis on demand.
func (s *redisPublicKeyStore) SavePublicKey(ctx context.Context, info *auth.KeyInfo) error {
	// NOTE(review): if ExpireAt is zero or already past, ttl is negative; how
	// the Redis client treats a negative expiration depends on the client
	// implementation — confirm this is intended.
	ttl := info.ExpireAt.Sub(time.Now())
	encodedKey, err := s.marshalPublicKey(info.Key)
	if err != nil {
		err = errors.WithMessage(err, "marshal public key failed")
		logger.Warnf("%v", err)
		return err
	}
	err = s.client.Set(ctx, s.buildKey(info.KeyId), encodedKey, ttl).Err()
	if err != nil {
		err = errors.WithMessage(err, "set public key to redis failed")
		logger.Warnf("%v", err)
		return err
	}
	ki := RedisPublicKeyInfo{
		KeyId:      info.KeyId,
		EncodedKey: encodedKey,
	}
	// Marshalling a struct of two string fields cannot fail.
	kiStr, _ := json.Marshal(ki)
	// Capture the publish result; previously it was discarded and the check
	// below inspected a stale, always-nil err.
	err = s.client.Publish(ctx, s.channel, kiStr).Err()
	if err != nil {
		err = errors.WithMessage(err, "publish public key update failed, abort publish and ignore error")
		logger.Warnf("%v", err)
	}
	return nil
}

// LoadPublicKey returns the public key registered under keyId. It returns nil
// if the key cannot be fetched or decoded; failures are logged.
//
// The lookup works as follows:
//   - It checks the in-process cache first.
//   - On a miss, it reads the key from Redis, decodes it and adds it to the
//     cache, so later lookups skip the round trip.
//   - Every lookup that returns a key is counted in cacheMatchCount, so
//     CleanUnusedKey can tell which keys are still in use.
//
// The mutex is not held during the Redis call, so a slow Redis does not stall
// other lookups or the subscription handler. If two callers miss on the same
// key concurrently, both may fetch it; the second write stores the same key.
//
// NOTE(review): once cached, a key outlives its Redis TTL until
// CleanUnusedKey evicts it. Keys received over pub/sub behave the same way.
// Confirm that expiry is enforced elsewhere.
func (s *redisPublicKeyStore) LoadPublicKey(ctx context.Context, keyId string) *ecdsa.PublicKey {
	s.Lock()
	if publicKey, ok := s.publicKeyCache[keyId]; ok {
		s.cacheMatchCount[keyId]++
		s.Unlock()
		return &publicKey
	}
	s.Unlock()

	val, err := s.client.Get(ctx, s.buildKey(keyId)).Result()
	if err != nil {
		logger.Warnf("%v", errors.WithMessage(err, fmt.Sprintf("get public key(%s) from redis failed", keyId)))
		return nil
	}
	key, err := s.unmarshalPublicKey(val)
	if err != nil {
		logger.Warnf("%v", errors.WithMessage(err, fmt.Sprintf("unmarshal public key(%s,%s) failed", keyId, val)))
		return nil
	}

	s.Lock()
	s.publicKeyCache[keyId] = *key
	s.cacheMatchCount[keyId]++
	s.Unlock()
	return key
}

// CleanUnusedKey evicts every cached public key that LoadPublicKey has not
// returned since the previous cleanup. It then resets the usage counters so
// the next period starts from zero.
//
// NewRedisPublicKeyStore calls it periodically from a background goroutine.
func (s *redisPublicKeyStore) CleanUnusedKey() {
	s.Lock()
	defer s.Unlock()
	// Iterate the cache, not the counters. Counters are only created by
	// incrementing, so an unused key has no entry at all and a stored zero
	// never occurs. Missing entries read as 0 here. Deleting from a map while
	// ranging over it is safe in Go.
	for keyId := range s.publicKeyCache {
		if s.cacheMatchCount[keyId] == 0 {
			delete(s.publicKeyCache, keyId)
		}
	}
	s.cacheMatchCount = make(map[string]int64)
}
